import logging
from pprint import pformat
from typing import Any, Dict, Sequence, Set

# common imports
from common.credentials import Password, Username
from common.event_queue import IAgentEventPublisher
from common.types import AgentID, Event, NetworkPort, NetworkService
from common.utils.code_utils import del_key

# dependencies to get rid of or internalize
from infection_monkey.exploit import (
    IAgentBinaryRepository,
    IAgentOTPProvider,
    IHTTPAgentBinaryServerRegistrar,
)
from infection_monkey.exploit.tools import (
    generate_brute_force_credentials,
    identity_type_filter,
    secret_type_filter,
)
from infection_monkey.i_puppet import ExploiterResult, TargetHost
from infection_monkey.network import TCPPortSelector
from infection_monkey.propagation_credentials_repository import IPropagationCredentialsRepository

from .mssql_client import MSSQLClient
from .mssql_exploit_client import MSSQLExploitClient
from .mssql_exploiter import MSSQLExploiter
from .mssql_options import MSSQLOptions

logger = logging.getLogger(__name__)


class Plugin:
    """
    Entry point of the MSSQL exploiter plugin

    The constructor stores the collaborators handed to the plugin. `run()` parses the plugin
    options, determines which ports should be attacked and delegates the exploitation itself to
    `MSSQLExploiter`.
    """

    def __init__(
        self,
        *,
        plugin_name: str,
        agent_id: AgentID,
        http_agent_binary_server_registrar: IHTTPAgentBinaryServerRegistrar,
        agent_event_publisher: IAgentEventPublisher,
        agent_binary_repository: IAgentBinaryRepository,
        propagation_credentials_repository: IPropagationCredentialsRepository,
        tcp_port_selector: TCPPortSelector,
        otp_provider: IAgentOTPProvider,
        **kwargs,
    ):
        """
        Store the collaborators for later use by `run()`

        :param plugin_name: Passed to `MSSQLExploitClient` as the exploiter name
        :param agent_id: Passed to both `MSSQLExploitClient` and `MSSQLExploiter`
        :param http_agent_binary_server_registrar: Passed to `MSSQLExploiter`
        :param agent_event_publisher: Passed to `MSSQLExploitClient`
        :param agent_binary_repository: Stored; not used by `run()`
        :param propagation_credentials_repository: Source of the credentials used for brute
                                                   forcing
        :param tcp_port_selector: Stored; not used by `run()`
        :param otp_provider: Passed to `MSSQLExploiter`
        :param kwargs: Accepted and ignored
        """
        # Identity of this plugin instance
        self._plugin_name = plugin_name
        self._agent_id = agent_id

        # Services the plugin relies on
        self._agent_event_publisher = agent_event_publisher
        self._http_agent_binary_server_registrar = http_agent_binary_server_registrar
        self._otp_provider = otp_provider
        self._propagation_credentials_repository = propagation_credentials_repository
        self._agent_binary_repository = agent_binary_repository
        self._tcp_port_selector = tcp_port_selector

    def run(
        self,
        *,
        host: TargetHost,
        servers: Sequence[str],
        current_depth: int,
        options: Dict[str, Any],
        interrupt: Event,
        **kwargs,
    ) -> ExploiterResult:
        """
        Attempt to exploit `host` over MSSQL

        :param host: The host to attack; its port scan results decide which ports are tried
        :param servers: Forwarded unchanged to `MSSQLExploiter.exploit_host()`
        :param current_depth: Forwarded unchanged to `MSSQLExploiter.exploit_host()`
        :param options: Raw plugin options; the "http_ports" entry is dropped via `del_key()`
                        and the rest is used to construct `MSSQLOptions`
        :param interrupt: Forwarded unchanged to `MSSQLExploiter.exploit_host()`
        :param kwargs: Accepted and ignored
        :return: The result of `MSSQLExploiter.exploit_host()`, or an `ExploiterResult` carrying
                 an error message if the options cannot be parsed, no port qualifies, or any
                 exception is raised while setting up or running the exploiter
        """
        # `http_ports` is a hack: it only shows up in the options because fingerprinters need
        # it, so it is dropped before the options are parsed
        del_key(options, "http_ports")

        try:
            logger.debug(f"Parsing options: {pformat(options)}")
            parsed_options = MSSQLOptions(**options)
        except Exception as err:
            parse_failure = f"Failed to parse MSSQL options: {err}"
            logger.exception(parse_failure)
            return ExploiterResult(error_message=parse_failure)

        candidate_ports = get_ports_to_try(host, parsed_options)
        if not candidate_ports:
            # Nothing to attack; report it as a failed (not errored-out) attempt
            no_ports_message = f"Host {host.ip} has no open MSSQL ports"
            logger.debug(no_ports_message)
            return ExploiterResult(
                exploitation_success=False,
                propagation_success=False,
                error_message=no_ports_message,
            )

        # Everything from here on is guarded: any failure becomes an error result rather than
        # an exception escaping the plugin
        try:
            logger.debug(f"Running MSSQL exploiter on host {host.ip}")

            # Only username/password pairs are used for brute forcing
            stored_credentials = self._propagation_credentials_repository.get_credentials()
            username_filter = identity_type_filter([Username])
            password_filter = secret_type_filter([Password])
            credentials_to_try = generate_brute_force_credentials(
                stored_credentials,
                identity_filter=username_filter,
                secret_filter=password_filter,
            )

            exploit_client = MSSQLExploitClient(
                exploiter_name=self._plugin_name,
                agent_id=self._agent_id,
                agent_event_publisher=self._agent_event_publisher,
                mssql_client=MSSQLClient(parsed_options.server_timeout),
            )
            exploiter = MSSQLExploiter(
                self._agent_id,
                exploit_client,
                self._http_agent_binary_server_registrar,
                self._otp_provider,
            )

            outcome = exploiter.exploit_host(
                host,
                parsed_options,
                servers,
                current_depth,
                credentials_to_try,
                candidate_ports,
                interrupt,
            )
        except Exception as err:
            unexpected_failure = (
                f"An unexpected exception occurred while attempting to exploit host: {err}"
            )
            logger.exception(unexpected_failure)
            return ExploiterResult(error_message=unexpected_failure)

        return outcome


def get_ports_to_try(host: TargetHost, options: MSSQLOptions) -> Set[NetworkPort]:
    """
    Work out which TCP ports on `host` should be attacked

    The result is the union of:
      - every port in `options.target_ports` that the scan did not report as closed
      - every open port whose detected service is MSSQL, if
        `options.try_discovered_mssql_ports` is set
      - every open port whose detected service is unknown, if
        `options.try_unknown_service_ports` is set

    :param host: The host whose TCP port scan results are consulted
    :param options: The parsed MSSQL options
    :return: The set of ports to try; empty if nothing qualifies
    """
    port_table = host.ports_status.tcp_ports

    # Configured targets are kept unless the scan reported them as closed
    selected: Set[NetworkPort] = set()
    for configured_port in set(options.target_ports):
        if configured_port in port_table.closed:
            continue
        selected.add(configured_port)

    # Services that make an open port worth trying, depending on the options
    wanted_services: Set[NetworkService] = set()
    if options.try_discovered_mssql_ports:
        wanted_services.add(NetworkService.MSSQL)
    if options.try_unknown_service_ports:
        wanted_services.add(NetworkService.UNKNOWN)

    for open_port in port_table.open:
        if port_table[open_port].service in wanted_services:
            selected.add(open_port)

    return selected
